{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6cr-JUn_1wRw"
      },
      "source": [
        "# MPLP: MNIST 20 step, 8 shot\n",
        "\n",
        "\n",
        "Copyright 2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "L8F2pl0H2Szr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import io\n",
        "import numpy as np\n",
        "import glob\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "tf.get_logger().setLevel('ERROR')\n",
        "\n",
        "import matplotlib.pyplot as plt # visualization\n",
        "\n",
        "import numpy as onp\n",
        "from collections import defaultdict\n",
        "\n",
        "import itertools\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow.compat.v2 as tf\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import IPython.display as display\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import os\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QBo8SVVqmLGU"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=mplp\u0026subdirectory=mplp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cAm7UBfbjJz6"
      },
      "outputs": [],
      "source": [
        "# symlink for saved models.\n",
        "!ln -s src/mplp/mplp/savedmodels savedmodels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "hEa0rSml_AC3"
      },
      "outputs": [],
      "source": [
        "\n",
        "from mplp.tf_layers import MPDense\n",
        "from mplp.tf_layers import MPActivation\n",
        "from mplp.tf_layers import MPSoftmax\n",
        "from mplp.tf_layers import MPCrossEntropyLoss\n",
        "from mplp.tf_layers import MPNetwork\n",
        "from mplp.util import SamplePool\n",
        "from mplp.training import TrainingRegime\n",
        "from mplp.core import GRUBlock\n",
        "from mplp.core import OutStandardizer\n",
        "from mplp.core import StandardizeInputsAndStates"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jKFOzgoO21hv"
      },
      "outputs": [],
      "source": [
        "\n",
        "def load_mnist():\n",
        "  (x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "  x_train, x_test = x_train / 255.0, x_test / 255.0\n",
        "\n",
        "  x_train = np.array(x_train).astype(np.float32)\n",
        "  y_train = np.array(y_train).astype(np.int64)\n",
        "  x_test = np.array(x_test).astype(np.float32)\n",
        "  y_test = np.array(y_test).astype(np.int64)\n",
        "  return (x_train, y_train),(x_test, y_test)\n",
        "\n",
        "(x_train, y_train),(x_test, y_test) = load_mnist()\n",
        "\n",
        "def one_hottify(dset):\n",
        "  one_hottified = onp.zeros([dset.shape[0], 10], dtype=onp.float32)\n",
        "  one_hottified[onp.arange(dset.size), dset] = 1.0\n",
        "  return one_hottified\n",
        "\n",
        "def MNIST_generator(inputs, labels):\n",
        "  while True:\n",
        "    idx = pyrandom.randrange(len(inputs))\n",
        "    x = inputs[idx]\n",
        "    y = labels[idx]\n",
        "    yield x, y\n",
        "\n",
        "PIC_L = 12\n",
        "\n",
        "x_train = tf.image.resize(onp.expand_dims(x_train, -1), size=(PIC_L, PIC_L)).numpy()\n",
        "x_test = tf.image.resize(onp.expand_dims(x_test, -1), size=(PIC_L, PIC_L)).numpy()\n",
        "y_train = one_hottify(y_train)\n",
        "y_test = one_hottify(y_test)\n",
        "\n",
        "# standardize inputs:\n",
        "train_mean, train_std = x_train.mean(), x_train.std()\n",
        "\n",
        "x_train = (x_train - train_mean) / train_std\n",
        "x_test = (x_test - train_mean) / train_std\n",
        "\n",
        "x_train = x_train.reshape([-1, PIC_L * PIC_L])\n",
        "x_test = x_test.reshape([-1, PIC_L * PIC_L])\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lzbQbrDy3UW1"
      },
      "outputs": [],
      "source": [
        "outer_batch_size = 1\n",
        "inner_batch_size = 8\n",
        "loop_steps = 20\n",
        "img_h = img_w = 12\n",
        "\n",
        "def resize(x, y):\n",
        "  x = tf.image.resize(x, [img_h, img_w])\n",
        "  x = tf.reshape(x, [-1, img_h * img_w])\n",
        "  return x, y\n",
        "\n",
        "resized_mnist_train = tf.data.Dataset.from_tensor_slices(\n",
        "    (x_train, y_train))#.map(resize)\n",
        "\n",
        "with tf.device(\"/cpu:0\"):\n",
        "  dataset = resized_mnist_train.batch(inner_batch_size).batch(loop_steps).batch(outer_batch_size).shuffle(1000).cache().repeat()\n",
        "  ds_iter = iter(dataset)\n",
        "  xval_ds = resized_mnist_train.batch(inner_batch_size).batch(outer_batch_size).shuffle(1000).cache().repeat()\n",
        "  xval_ds_iter = iter(xval_ds)\n",
        "\n",
        "  test_ds = tf.data.Dataset.from_generator(\n",
        "            lambda: MNIST_generator(x_test, y_test),\n",
        "            output_types=(onp.float32, onp.float32))\n",
        "  test_ds = test_ds.batch(inner_batch_size)\n",
        "  test_ds_iter = iter(test_ds)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "B4n7iH6e2lYz"
      },
      "outputs": [],
      "source": [
        "message_size = 4\n",
        "stateful = True\n",
        "stateful_hidden_n = 7\n",
        "\n",
        "import functools\n",
        "\n",
        "l0 = PIC_L * PIC_L\n",
        "def init_network(activation, sizes):\n",
        "  network = MPNetwork([MPDense(sizes[1], stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "                    MPActivation(activation, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "                    MPDense(10, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "                    MPSoftmax(stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "                    ],\n",
        "                    MPCrossEntropyLoss(message_size, stateful=stateful, stateful_hidden_n=stateful_hidden_n))\n",
        "  # create shared networks!\n",
        "  shared_params = {\n",
        "      \"W_net\": GRUBlock(x_dim=2 + message_size,\n",
        "                        carry_n=1+message_size+stateful_hidden_n),\n",
        "      \"b_net\": GRUBlock(x_dim=2 + message_size,\n",
        "                        carry_n=1+message_size+stateful_hidden_n),\n",
        "#      \"W_out_std\": OutStandardizer(scale_init_val=0.05),\n",
        "#      \"b_out_std\": OutStandardizer(scale_init_val=0.05),\n",
        "#      \"W_in_std\": StandardizeInputsAndStates([\"W_b\", \"W_in\"]),\n",
        "#      \"b_in_std\": StandardizeInputsAndStates([\"b_b\", \"b_in\"]),\n",
        "                   }\n",
        "\n",
        "  network.setup(in_dim=sizes[0], message_size=message_size, \n",
        "                inner_batch_size=inner_batch_size,\n",
        "                shared_params=shared_params)\n",
        "  return network\n",
        "\n",
        "network = init_network(tf.sigmoid, (l0, 50))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "IKe2E8miWnpM"
      },
      "outputs": [],
      "source": [
        "def create_new_p_fw():\n",
        "  return network.init(inner_batch_size)\n",
        "\n",
        "POOL_SIZE = 16\n",
        "\n",
        "pool = SamplePool(ps=tf.stack(\n",
        "    [network.serialize_pfw(create_new_p_fw()) for _ in range(POOL_SIZE)]).numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jtVVwKNrWzpt"
      },
      "outputs": [],
      "source": [
        "loop_steps_t = tf.constant(loop_steps)\n",
        "\n",
        "learning_schedule = 5e-4\n",
        "\n",
        "training_regime = TrainingRegime(\n",
        "    network, heldout_weight=1.0, hint_loss_ratio=0.7, remember_loss_ratio=None)\n",
        "\n",
        "last_step = 0\n",
        "\n",
        "# minibatch init, allowing to initialize by looking at more\n",
        "# than just one step.\n",
        "# Likewise, this can be run multiple times to improve the initialization.\n",
        "for j in range(5):\n",
        "  print(\"on\", j)\n",
        "  stats = []\n",
        "  pfw = network.init(inner_batch_size)\n",
        "\n",
        "  x_b, y_b = next(ds_iter)\n",
        "  x_b, y_b = x_b[0], y_b[0]\n",
        "  for i in range(loop_steps):\n",
        "    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)\n",
        "    stats.append(stats_i)\n",
        "  # update\n",
        "  network.update_statistics(stats, update_perc=0.5)\n",
        "\n",
        "  print(\"final mean:\")\n",
        "  for p in tf.nest.flatten(pfw):\n",
        "    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))\n",
        "\n",
        "\n",
        "# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.\n",
        "trainer = tf.keras.optimizers.Adam(learning_schedule)\n",
        "\n",
        "loss_log = []\n",
        "def smoothen(l, lookback=20):\n",
        "  # first of all, if it's a nan, change it to a high value\n",
        "  kernel = [1./lookback] * lookback\n",
        "  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, \"valid\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4cByI6hXGbHD"
      },
      "outputs": [],
      "source": [
        "!mkdir tmp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "AilvJiIj3zhS"
      },
      "outputs": [],
      "source": [
        "\n",
        "training_steps = 100000\n",
        "\n",
        "learning_schedule = 1e-4\n",
        "\n",
        "@tf.function\n",
        "def step(pfws, xts, yts, xes, yes, num_steps):\n",
        "  print(\"compiling\")\n",
        "  with tf.GradientTape() as g:\n",
        "    l, _, _ = training_regime.batch_mp_loss(pfws, xts, yts, xes, yes, num_steps)\n",
        "  all_weights = network.get_trainable_weights()\n",
        "  grads = g.gradient(l, all_weights)\n",
        "  # Try grad clipping to avoid explosions.\n",
        "  grads = [g/(tf.norm(g)+1e-8) for g in grads]\n",
        "  trainer.apply_gradients(zip(grads, all_weights))\n",
        "  return l\n",
        "\n",
        "\n",
        "import time\n",
        "start_time = time.time()\n",
        "\n",
        "for i in range(last_step + 1, last_step +1 + training_steps):\n",
        "  last_step = i\n",
        "\n",
        "  tmp_t = time.time()\n",
        "  xts, yts = next(ds_iter)\n",
        "  xes, yes = next(xval_ds_iter)\n",
        "\n",
        "  batch = pool.sample(outer_batch_size)\n",
        "  fwps = batch.ps\n",
        "\n",
        "  l = step(fwps, xts, yts, xes, yes, loop_steps_t)\n",
        "  loss_log.append(l)\n",
        "\n",
        "  if i % 50 == 0:\n",
        "    print(i)\n",
        "    print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "  if i % 500 == 0:\n",
        "    plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "    plt.yscale('log')\n",
        "    plt.legend()\n",
        "    plt.show()\n",
        "  if i % 2500 == 0:\n",
        "    file_path = \"tmp/weights\"\n",
        "    network.save_weights(file_path, last_step)\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n",
        "np_batched_stl_mpv2_nograd_loss = time_series\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0QWqVlQ5_uiK"
      },
      "outputs": [],
      "source": [
        "!ls tmp -lh"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "evt-q7nU8eC4"
      },
      "outputs": [],
      "source": [
        "def smoothen(l, lookback=20):\n",
        "  kernel = [1./lookback] * lookback\n",
        "  return onp.convolve(l[0:1] * (lookback - 1) + l, kernel, \"valid\")\n",
        "\n",
        "plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "plt.yscale('log')\n",
        "plt.ylim(1e-2, 3e-1)\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lNR0lsrVRoKn"
      },
      "outputs": [],
      "source": [
        "# check saved checkpoints\n",
        "! ls tmp -lh"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MhNHDPzjfwYt"
      },
      "outputs": [],
      "source": [
        "#@title Optionally, load weights from savedmodel\n",
        "# override file_path as it will be useful later on.\n",
        "file_path = \"savedmodels/mnist_weights\"\n",
        "\n",
        "network = init_network(tf.sigmoid, (l0, 50))\n",
        "network.load_weights(file_path)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "D-Cuifh7wtuo"
      },
      "outputs": [],
      "source": [
        "# Now let's generate a training regime of 10 steps, and compare with SGD.\n",
        "\n",
        "eval_pfw = create_new_p_fw()\n",
        "\n",
        "def prepare_for_analysis(pfw):\n",
        "  return tf.concat([tf.reshape(t, [-1]) for t in pfw], 0).numpy()\n",
        "\n",
        "print(prepare_for_analysis(eval_pfw).shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MZTnYwEtQ1Sq"
      },
      "outputs": [],
      "source": [
        "\n",
        "test_sp_bs = 100\n",
        "\n",
        "with tf.device(\"/cpu:0\"):\n",
        "  dataset_sp = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "  dataset_sp = dataset_sp.batch(inner_batch_size).shuffle(1000).repeat()\n",
        "  dataset_sp_iter = iter(dataset_sp)\n",
        "  test_sp_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
        "  test_sp_ds = test_sp_ds.batch(test_sp_bs).repeat()\n",
        "  test_sp_ds_iter = iter(test_sp_ds)\n",
        "\n",
        "\n",
        "def get_accuracy(pfw, xe, ye):\n",
        "  targets = tf.argmax(ye, axis=-1)\n",
        "  res, _ = network.forward(pfw, xe)\n",
        "  predictions = tf.argmax(res, axis=-1)\n",
        "\n",
        "  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))\n",
        "  accuracy = tot_correct / ye.shape[0]\n",
        "\n",
        "  return accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vPYIYhBw93AZ"
      },
      "outputs": [],
      "source": [
        "# MP step accuracy\n",
        "all_MP_series = []\n",
        "eval_tot_steps = 100\n",
        "ev_losses = np.zeros([eval_tot_steps, loop_steps])\n",
        "for s in range(eval_tot_steps):\n",
        "  if s % 10 == 0:\n",
        "    print(\"\\nRepetition {}\".format(s))\n",
        "  MP_params_series = []\n",
        "\n",
        "  mp_pfw = network.init(inner_batch_size)\n",
        "\n",
        "\n",
        "  for i in range(loop_steps):\n",
        "    # using test to train too, because we want to see the effect of learning\n",
        "    # on the parameter space.\n",
        "    xt, yt = next(dataset_sp_iter)\n",
        "    \n",
        "    mp_pfw, _= network.inner_update(mp_pfw, xt, yt)\n",
        "\n",
        "\n",
        "    xe, ye = next(test_sp_ds_iter)\n",
        "    accuracy = get_accuracy(mp_pfw, xe, ye)\n",
        "    ev_losses[s, i] = accuracy\n",
        "\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "print(\"mean:\", ev_losses_m)\n",
        "print(\"sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval loss')\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "mp_baseline_m = ev_losses_m"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PGlLBdbhBPmu"
      },
      "outputs": [],
      "source": [
        "# ADAM network\n",
        "adam_network = MPNetwork([MPDense(50, stateful=False),\n",
        "                  MPActivation(tf.sigmoid, stateful=False),\n",
        "                  MPDense(10, stateful=False),\n",
        "                  MPSoftmax(stateful=False),\n",
        "                  ],\n",
        "                  MPCrossEntropyLoss(message_size, stateful=False))\n",
        "adam_network.setup(in_dim=l0, message_size=message_size)\n",
        "\n",
        "\n",
        "\n",
        "def get_adam_accuracy(pfw, xe, ye):\n",
        "  targets = tf.argmax(ye, axis=-1)\n",
        "  res, _ = adam_network.forward(pfw, xe)\n",
        "  predictions = tf.argmax(res, axis=-1)\n",
        "\n",
        "  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))\n",
        "  accuracy = tot_correct / ye.shape[0]\n",
        "\n",
        "  return accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OC8Zszt88fIx"
      },
      "outputs": [],
      "source": [
        "# SGD step\n",
        "\n",
        "all_SGD_series = []\n",
        "eval_tot_steps = 100\n",
        "sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])\n",
        "for s in range(eval_tot_steps):\n",
        "  if s % 10 == 0:\n",
        "    print(\"\\nRepetition {}\".format(s))\n",
        "  SGD_params_series = []\n",
        "\n",
        "  sgd_pfw = [tf.Variable(t) for t in adam_network.init()]\n",
        "\n",
        "  adam_trainer = tf.keras.optimizers.SGD(0.1)\n",
        "\n",
        "  @tf.function\n",
        "  def step(xt, yt):\n",
        "    with tf.GradientTape() as g:\n",
        "      g.watch(sgd_pfw)\n",
        "      y, _ = adam_network.forward(sgd_pfw, xt)\n",
        "      l, _ = adam_network.compute_loss(y, yt)\n",
        "      l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))\n",
        "\n",
        "    grads = g.gradient(l, sgd_pfw)\n",
        "    adam_trainer.apply_gradients(zip(grads, sgd_pfw))\n",
        "\n",
        "\n",
        "  for i in range(loop_steps):\n",
        "    # using test to train too, because we want to see the effect of learning\n",
        "    # on the parameter space.\n",
        "    xt, yt = next(dataset_sp_iter)\n",
        "    step(xt, yt)\n",
        "\n",
        "    if i % 1 == 0:\n",
        "      xe, ye = next(test_sp_ds_iter)\n",
        "      accuracy = get_adam_accuracy(sgd_pfw, xe, ye)\n",
        "      sgd_ev_losses[s, i] = accuracy\n",
        "\n",
        "ev_losses_m = np.mean(sgd_ev_losses, axis=0)\n",
        "ev_losses_sd = np.std(sgd_ev_losses, axis=0)\n",
        "print(\"mean:\", ev_losses_m)\n",
        "print(\"sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval loss')\n",
        "#plt.ylim(0.0, 0.04)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "sgd_m = ev_losses_m"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wOWy2FaUGARI"
      },
      "outputs": [],
      "source": [
        "# Adam step\n",
        "\n",
        "all_SGD_series = []\n",
        "eval_tot_steps = 100\n",
        "sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])\n",
        "for s in range(eval_tot_steps):\n",
        "  if s % 10 == 0:\n",
        "    print(\"\\nRepetition {}\".format(s))\n",
        "  SGD_params_series = []\n",
        "\n",
        "  sgd_pfw = [tf.Variable(t) for t in adam_network.init()]\n",
        "\n",
        "\n",
        "  adam_trainer = tf.keras.optimizers.Adam(0.01)\n",
        "\n",
        "  @tf.function\n",
        "  def step(xt, yt):\n",
        "    with tf.GradientTape() as g:\n",
        "      g.watch(sgd_pfw)\n",
        "      y, _ = adam_network.forward(sgd_pfw, xt)\n",
        "      l, _ = adam_network.compute_loss(y, yt)\n",
        "      l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))\n",
        "\n",
        "    grads = g.gradient(l, sgd_pfw)\n",
        "    adam_trainer.apply_gradients(zip(grads, sgd_pfw))\n",
        "\n",
        "\n",
        "  for i in range(loop_steps):\n",
        "    # using test to train too, because we want to see the effect of learning\n",
        "    # on the parameter space.\n",
        "    xt, yt = next(dataset_sp_iter)\n",
        "    step(xt, yt)\n",
        "\n",
        "    if i % 1 == 0:\n",
        "      xe, ye = next(test_sp_ds_iter)\n",
        "      accuracy = get_adam_accuracy(sgd_pfw, xe, ye)\n",
        "      sgd_ev_losses[s, i] = accuracy\n",
        "\n",
        "ev_losses_m = np.mean(sgd_ev_losses, axis=0)\n",
        "ev_losses_sd = np.std(sgd_ev_losses, axis=0)\n",
        "print(\"mean:\", ev_losses_m)\n",
        "print(\"sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval loss')\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "adam_m = ev_losses_m"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "09WKw-CIZFmR"
      },
      "source": [
        "# See what happens if you change activation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oRBNnC3GZQSC"
      },
      "outputs": [],
      "source": [
        "\n",
        "new_activation = lambda x: tf.maximum(0.0, tf.sign(x))\n",
        "\n",
        "l0 = PIC_L * PIC_L\n",
        "\n",
        "new_network = init_network(new_activation, (l0, 50))\n",
        "new_network.load_weights(file_path)\n",
        "\n",
        "def get_accuracy(net, pfw, xe, ye):\n",
        "  targets = tf.argmax(ye, axis=-1)\n",
        "  res, _ = net.forward(pfw, xe)\n",
        "  predictions = tf.argmax(res, axis=-1)\n",
        "\n",
        "  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))\n",
        "  accuracy = tot_correct / ye.shape[0]\n",
        "\n",
        "  return accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X8crDt9UZfzU"
      },
      "outputs": [],
      "source": [
        "# MP step accuracy\n",
        "all_MP_series = []\n",
        "eval_tot_steps = 100\n",
        "ev_losses = np.zeros([eval_tot_steps, loop_steps])\n",
        "for s in range(eval_tot_steps):\n",
        "  if s % 10 == 0:\n",
        "    print(\"\\nRepetition {}\".format(s))\n",
        "  MP_params_series = []\n",
        "\n",
        "  mp_pfw = new_network.init(inner_batch_size)\n",
        "\n",
        "  for i in range(loop_steps):\n",
        "    # using test to train too, because we want to see the effect of learning\n",
        "    # on the parameter space.\n",
        "    xt, yt = next(dataset_sp_iter)\n",
        "    \n",
        "    mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)\n",
        "\n",
        "    xe, ye = next(test_sp_ds_iter)\n",
        "    accuracy = get_accuracy(new_network, mp_pfw, xe, ye)\n",
        "    ev_losses[s, i] = accuracy\n",
        "\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "print(\"mean:\", ev_losses_m)\n",
        "print(\"sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval loss')\n",
        "#plt.ylim(0.0, 0.04)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "mp_stepf_m = ev_losses_m"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "H4IsIapATTYv"
      },
      "source": [
        "# See what happens if you transfer the learned parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RWnqnkoSTjCH"
      },
      "outputs": [],
      "source": [
        "# big version of MNIST\n",
        "(x_train_b, y_train_b), (x_test_b, y_test_b) = load_mnist()\n",
        "y_train_b = one_hottify(y_train_b)\n",
        "y_test_b = one_hottify(y_test_b)\n",
        "\n",
        "# standardize inputs:\n",
        "train_mean, train_std = x_train_b.mean(), x_train_b.std()\n",
        "\n",
        "x_train_b = (x_train_b - train_mean) / train_std\n",
        "x_test_b = (x_test_b - train_mean) / train_std\n",
        "\n",
        "x_train_b = x_train_b.reshape([-1, 28 * 28])\n",
        "x_test_b = x_test_b.reshape([-1, 28* 28])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "166Gjk3OTZ5w"
      },
      "outputs": [],
      "source": [
        "new_network = init_network(tf.sigmoid, (28*28, 100))\n",
        "new_network.load_weights(file_path)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1beVZtck3TOh"
      },
      "outputs": [],
      "source": [
        "!ls tmp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4NbxAZDaUr3c"
      },
      "outputs": [],
      "source": [
        "\n",
        "test_sp_bs = 100\n",
        "\n",
        "with tf.device(\"/cpu:0\"):\n",
        "  dataset_sp_b = tf.data.Dataset.from_tensor_slices((x_train_b, y_train_b))\n",
        "  dataset_sp_b = dataset_sp_b.batch(inner_batch_size).shuffle(1000).repeat()\n",
        "  dataset_sp_b_iter = iter(dataset_sp_b)\n",
        "  test_sp_b_ds = tf.data.Dataset.from_tensor_slices((x_test_b, y_test_b))\n",
        "  test_sp_b_ds = test_sp_b_ds.batch(test_sp_bs).repeat()\n",
        "  test_sp_b_ds_iter = iter(test_sp_b_ds)\n",
        "\n",
        "\n",
        "def get_accuracy(pfw, xe, ye):\n",
        "  targets = tf.argmax(ye, axis=-1)\n",
        "  res, _ = new_network.forward(pfw, xe)\n",
        "  predictions = tf.argmax(res, axis=-1)\n",
        "\n",
        "  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))\n",
        "  accuracy = tot_correct / ye.shape[0]\n",
        "\n",
        "  return accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xgw22TAvUr3g"
      },
      "outputs": [],
      "source": [
        "# MP step accuracy\n",
        "# beware, this is very slow.\n",
        "all_MP_series = []\n",
        "eval_tot_steps = 100\n",
        "ev_losses = np.zeros([eval_tot_steps, loop_steps])\n",
        "for s in range(eval_tot_steps):\n",
        "  if s % 10 == 0:\n",
        "    print(\"\\nRepetition {}\".format(s))\n",
        "  MP_params_series = []\n",
        "\n",
        "  mp_pfw = new_network.init(inner_batch_size)\n",
        "\n",
        "  for i in range(loop_steps):\n",
        "    # using test to train too, because we want to see the effect of learning\n",
        "    # on the parameter space.\n",
        "    xt, yt = next(dataset_sp_b_iter)\n",
        "    \n",
        "    mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)\n",
        "\n",
        "    xe, ye = next(test_sp_b_ds_iter)\n",
        "    accuracy = get_accuracy(mp_pfw, xe, ye)\n",
        "    ev_losses[s, i] = accuracy\n",
        "\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "print(\"mean:\", ev_losses_m)\n",
        "print(\"sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval loss')\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "mp_big_net_m = ev_losses_m"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "x9vtHxiOUVgU"
      },
      "outputs": [],
      "source": [
        "# plot all together.\n",
        "\n",
        "x_values = range(1, len(mp_baseline_m) + 1)\n",
        "\n",
        "plt.plot(x_values, mp_baseline_m * 100, label='MPLP trained')\n",
        "plt.plot(x_values, mp_stepf_m * 100, label='MPLP step function')\n",
        "plt.plot(x_values, mp_big_net_m * 100, label='MPLP on bigger network')\n",
        "plt.plot(x_values, sgd_m * 100, label='SGD')\n",
        "plt.plot(x_values, adam_m * 100, label='Adam')\n",
        "plt.xticks([4, 8, 12, 16, 20])\n",
        "\n",
        "plt.xlabel(\"Number of steps\")\n",
        "plt.ylabel(\"Accuracy (%)\")\n",
        "plt.legend()\n",
        "\n",
        "\n",
        "with open(\"tmp/mplp_evals.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//experimental/selforg:selforg_colab",
        "kind": "private"
      },
      "name": "mplp_MNIST.ipynb",
      "provenance": [
        {
          "file_id": "124kqo8z51uqbXZHgLDsw-gJV6ySwnccb",
          "timestamp": 1593621699990
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/self_organising_systems/mplp/notebooks/sharednet_ms_4_L2L_MNIST.ipynb?workspaceId=etr:twp::citc",
          "timestamp": 1593614608800
        },
        {
          "file_id": "1CfiYwYxmArmYc6xgyqH9S9_uDtQCQ_xA",
          "timestamp": 1593437800424
        },
        {
          "file_id": "14Gz-Rpnog-o4kk2hy_FltL6V7TzBGhZ3",
          "timestamp": 1589037587398
        },
        {
          "file_id": "16gXye-iMJYHdW8p43ibkRfCKr9l3Tuxv",
          "timestamp": 1589027248195
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
